from __future__ import division, print_function
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import mi_estimator
import warnings
# To prevent PIL warnings.
warnings.filterwarnings("ignore")

from torchmetrics import Accuracy
from torchvision import models
from torch.autograd import Variable
from tqdm import tqdm
import torchnet as tnt
from art.estimators.classification import PyTorchClassifier
import art.attacks.evasion

import cifar10models
import cifar100models
import data
import utils
from manager import Manager
from autoattack.autoattack import AutoAttack



###General flags
### --network, --attacktype and --dataset have no usable default: main() builds file paths and
### looks up per-network tables from them, so they are required rather than silently None.
FLAGS = argparse.ArgumentParser()

FLAGS.add_argument('--network', required=True, choices=['AlexNet', 'VGG16', 'Resnet50'], help='Architectures')
FLAGS.add_argument('--attacktype', required=True, choices=['FGSM', 'C&W', 'PGD', 'I-FGSM', 'AutoAttack'], help='Type of adversarial attack used')
FLAGS.add_argument('--num_fb_layers', type=int, default=4, help='Number of layers allocated to subnetwork fb')
FLAGS.add_argument('--dataset', required=True, choices=['CIFAR10', 'CIFAR100', 'MNIST', 'Imagenette2'], help='Dataset used for training')
FLAGS.add_argument('--batchsize', type=int, default=512, help='Batch size')
FLAGS.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
FLAGS.add_argument('--eps', type=int, default=8, help='Perturbation magnitude, to be divided by 255')
FLAGS.add_argument('--eps_step', type=int, default=1, help='Perturbation step size for each iteration, to be divided by 255')
FLAGS.add_argument('--attack_iterations', type=int, default=-1, help='Number of iterations in attack')
FLAGS.add_argument('--learning_rate', type=float, default=1e-2, help='Learning rate')
FLAGS.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
# NOTE: store_true combined with default=True means args.cuda is always True; the flag cannot disable CUDA.
FLAGS.add_argument('--cuda', action='store_true', default=True, help='use CUDA')

   
######################################################################################################################################################################
###
###     Main function
###
######################################################################################################################################################################



def _dataset_to_numpy(dataset):
    """Convert an iterable of ``(image_tensor, label)`` pairs into numpy arrays.

    Args:
        dataset: Iterable of ``(tensor, label)`` pairs, such as the objects
            returned by ``data.Dataset.train_dataset()``.

    Returns:
        ``(x, y)``: ``x`` stacks every image tensor (detached and moved to CPU)
        into one numpy array; ``y`` holds the labels cast to ``int64`` so they
        can be used as class-index targets.
    """
    images, labels = zip(*dataset)
    x = np.asarray([item.detach().cpu().numpy() for item in images])
    y = np.asarray(labels).astype("int64")
    return x, y


def _pairs(xs, ys):
    """Return ``[[x, y], ...]`` pairing each sample with its label, usable as a map-style DataLoader dataset."""
    return [[x, y] for x, y in zip(xs, ys)]


def _load_pretrained_model(args):
    """Build ``args.network`` for ``args.dataset`` and load its weights pretrained on non-adversarial data.

    Raises:
        ValueError: if no pretrained model is available for the (dataset, network) combination.
    """
    if args.dataset == "Imagenette2":
        constructors = {
            'AlexNet': models.AlexNet,
            'VGG16': models.vgg16_bn,
            'Resnet50': models.resnet50,
        }
        if args.network in constructors:
            model = constructors[args.network](num_classes=10)
            state_dict = torch.load('../state_dicts/Imagenette2/' + args.network + '/pretrain_dict.zip')
            model.load_state_dict(state_dict)
            return model
    elif args.dataset == "CIFAR10":
        if args.network == 'AlexNet':
            model = cifar10models.AlexNet()
            state_dict = torch.load('../state_dicts/' + args.dataset + '/model_best.pth.tar')['state_dict']
            # Strip the 'module.' key prefix so the keys match the bare model
            # (presumably the checkpoint was saved from an nn.DataParallel wrapper).
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict)
            return model
        if args.network == 'VGG16':
            return cifar10models.vgg16_bn(pretrained=True)
        if args.network == 'Resnet50':
            return cifar10models.resnet50(pretrained=True)
    elif args.dataset == "CIFAR100":
        if args.network == 'AlexNet':
            model = cifar100models.AlexNet(num_classes=100)
            state_dict = torch.load('../state_dicts/CIFAR100/AlexNet/pretrain_dict.zip')
            model.load_state_dict(state_dict)
            return model
        if args.network == 'VGG16':
            return cifar100models.vgg16_bn(pretrained=True)
        if args.network == 'Resnet50':
            return cifar100models.resnet50(pretrained=True)
    raise ValueError(f"No pretrained {args.network} model is available for dataset {args.dataset}")


def _default_attack_iterations(eps, eps_step):
    """Return ceil(min(4 + eps/eps_step, 1.25 * eps/eps_step)) as a Python int.

    This is the iteration count suggested by Kurakin, Goodfellow & Bengio,
    "Adversarial examples in the physical world". The result is a plain ``int``
    (not a numpy integer) because ART's iterative attacks type-check ``max_iter``.
    """
    ratio = eps / eps_step
    return int(np.ceil(min(4 + ratio, 1.25 * ratio)))


def _build_attack(args, model, lossfn, optimizer):
    """Create the evasion attack selected by ``args.attacktype``.

    ``--eps`` and ``--eps_step`` are both given in units of 1/255 (see their help
    text), so both are rescaled here to the [0, 1] range used by ``clip_values``.

    Raises:
        ValueError: for an unknown attack type.
    """
    eps = args.eps / 255
    eps_step = args.eps_step / 255

    if args.attacktype == 'AutoAttack':
        attack = AutoAttack(model.cuda(), norm='Linf', eps=eps)
        attack.apgd.n_restarts = 1
        return attack

    input_shape, nb_classes = {
        "CIFAR10": ((3, 32, 32), 10),
        "CIFAR100": ((3, 32, 32), 100),
        "Imagenette2": ((3, 224, 224), 10),
    }[args.dataset]
    classifier = PyTorchClassifier(model=model, clip_values=(0.0, 1.0), loss=lossfn, optimizer=optimizer,
                                   input_shape=input_shape, nb_classes=nb_classes)

    if args.attacktype == 'FGSM':
        return art.attacks.evasion.FastGradientMethod(estimator=classifier, eps=eps)
    if args.attacktype == 'I-FGSM':
        return art.attacks.evasion.BasicIterativeMethod(estimator=classifier, eps=eps, eps_step=eps_step,
                                                        max_iter=args.attack_iterations)
    if args.attacktype == 'PGD':
        return art.attacks.evasion.ProjectedGradientDescentPyTorch(estimator=classifier, eps=eps, eps_step=eps_step,
                                                                   max_iter=args.attack_iterations)
    if args.attacktype == 'C&W':
        return art.attacks.evasion.CarliniL2Method(classifier=classifier)
    raise ValueError(f"Unsupported attack type: {args.attacktype}")


def _generate_adversarial(attack, attacktype, x_np, y):
    """Run ``attack`` on ``(x_np, y)`` and return the adversarial inputs as a numpy array.

    AutoAttack works on torch tensors; the ART attacks work on numpy arrays.
    """
    if attacktype == 'AutoAttack':
        x_adv = attack.run_standard_evaluation(torch.from_numpy(x_np), torch.from_numpy(y))
        return x_adv.detach().cpu().numpy()
    return attack.generate(x_np, y)


def _mi_estimates(manager, weight_layers, f_b_index, x_split, label_counts, count=16):
    """Estimate MI between consecutive weight layers, walking backwards from the output.

    Entry ``m`` (0-based) pairs ``weight_layers[L-m-2]`` with ``weight_layers[L-m-1]``,
    where ``L = len(weight_layers)``. It is only computed when the later layer of
    the pair belongs to f_b (``f_b_index <= L-m-1``); otherwise the entry is ``inf``.
    Estimates are computed in order from the output layer backwards.
    """
    num_layers = len(weight_layers)
    estimates = []
    for m in range(count):
        upper = num_layers - 1 - m
        if f_b_index <= upper:
            estimates.append(manager.get_mi_estimate(weight_layers[upper - 1], weight_layers[upper], x_split, label_counts))
        else:
            estimates.append(float('inf'))
    return estimates


def main():
    """Run the adversarial-splicing experiment configured by the command-line flags.

    Steps:
      1. Load the clean train/val/test data and a model pretrained on clean data, (f_a, f_b).
      2. Generate (or load cached) adversarial copies of every split.
      3. Adversarially fine-tune the whole network into (f_a*, f_b*), checkpointing on validation accuracy.
      4. Splice the clean f_b weights back onto f_a* to obtain (f_a*, f_b).
      5. For each trial, retrain only f_b from the spliced state and record MI estimates
         between consecutive weight layers plus accuracies and timings.
      6. Save the per-trial statistics to ``./results/.../results.npy``.

    Raises:
        ValueError: for an unsupported dataset or an out-of-range ``--num_fb_layers``.
    """
    args = FLAGS.parse_args()
    torch.cuda.set_device(0)

    print("##############################################################\n\
            Network: ", args.network ,"\n\
            Attack:", args.attacktype ,"\n\
            Dataset:", args.dataset ,"\n\
            Epsilon:", args.eps ,"\n\
            Attack Iterations:", args.attack_iterations ,"\n\
            Epochs:", args.epochs , "\n\
            num_fb_layers:", args.num_fb_layers, "\n")
    print("##############################################################")

#########################################################################################
###    Validate the configuration (before any expensive data loading)
#########################################################################################

    ### Dictionaries of the indices of all trainable layers (e.g. Conv2D and Linear) to index by num_fb_layers
    layerdict_CIFAR = {
        "AlexNet":[2,5,8,10,12,15],
        "VGG16":[2,5,9,12,16,19,22,26,29,32,36,39,42,48,51,54],
        "Resnet50":[1,7,9,11,15,18,20,22,26,28,30,35,37,39,43,46,48,50,54,56,58,62,64,66,71,73,75,79,82,84,86,90,92,94,98,100,102,106,108,110,114,116,118,123,125,127,131,134,136,138,142,144,146,150]
    }

    layerdict_Imagenette2 = {
        "AlexNet":[2,5,8,10,12,18,21,23],
        "VGG16":[2,5,9,12,16,19,22,26,29,32,36,39,42,48,51,54],
        "Resnet50":[1,7,9,11,15,18,20,22,26,28,30,35,37,39,43,46,48,50,54,56,58,62,64,66,71,73,75,79,82,84,86,90,92,94,98,100,102,106,108,110,114,116,118,123,125,127,131,134,136,138,142,144,146,150]
    }

    layerdicts = {"Imagenette2": layerdict_Imagenette2, "CIFAR10": layerdict_CIFAR, "CIFAR100": layerdict_CIFAR}
    if args.dataset not in layerdicts:
        raise ValueError(f"Dataset {args.dataset} is not supported by this experiment "
                         f"(supported: {', '.join(layerdicts)})")

    ### Get the number of layers and indices of weight layers in the network for the given dataset
    weight_layers = layerdicts[args.dataset][args.network]
    num_layers = len(weight_layers)

    ### Get index of the first layer in f_b. It must be >= 1: the MI estimates pair each f_b layer with
    ### the layer before it, so index f_b_index - 1 must exist (a negative index would silently wrap).
    f_b_index = num_layers - args.num_fb_layers
    if not 0 < f_b_index < num_layers:
        raise ValueError(f"--num_fb_layers must be between 1 and {num_layers - 1} for {args.network} "
                         f"on {args.dataset}, got {args.num_fb_layers}")
    f_b_start = weight_layers[f_b_index]
    print(f'Set-Up: {f_b_index}, {f_b_start}, {args.attacktype}, {args.network}, {args.eps}')

#########################################################################################
###    Prepare Data and Loaders
#########################################################################################

    ### Prepares the data for the chosen dataset as well as dataloaders to pass to the Manager for training and evaluation
    datavar = data.Dataset(("../data/" + args.dataset), args.dataset)

    x_train_np, y_train = _dataset_to_numpy(datavar.train_dataset())
    x_test_np, y_test = _dataset_to_numpy(datavar.test_dataset())

    trainloader = datavar.train_dataloader()

    if args.dataset != "Imagenette2":
        ### For cifar10 and cifar100 the validation dataset is prepared directly within data.Dataset()
        x_val_np, y_val = _dataset_to_numpy(datavar.val_dataset())
        valloader = datavar.val_dataloader()
        testloader = datavar.test_dataloader()
    else:
        ### Due to the format of neuralmagic's sparseml implementation of Imagenette,
        ###     we indirectly convert in and out of numpy to split off a validation subset:
        ###     the first ceil(N/2) test samples become validation, the rest stay test.
        half = int(np.ceil(len(x_test_np) / 2))
        x_val_np, y_val = x_test_np[:half], y_test[:half]
        x_test_np, y_test = x_test_np[half:], y_test[half:]

        eval_loader_kwargs = dict(batch_size=args.batchsize, num_workers=1, shuffle=False, drop_last=True, pin_memory=True)
        valloader = torch.utils.data.DataLoader(_pairs(x_val_np, y_val), **eval_loader_kwargs)
        testloader = torch.utils.data.DataLoader(_pairs(x_test_np, y_test), **eval_loader_kwargs)

#########################################################################################
###    Prepare The Model
#########################################################################################

    ### Load the appropriate model pretrained on non-adversarial data
    model = _load_pretrained_model(args)

    lossfn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9, nesterov=True)

    ### If no iteration value is given (default -1), calculate one based on the suggestions in
    ### https://www.taylorfrancis.com/chapters/edit/10.1201/9781351251389-8/adversarial-examples-physical-world-alexey-kurakin-ian-goodfellow-samy-bengio as  min(4+ϵ/α,1.25⋅ϵ/α)
    if args.attack_iterations <= 0:
        args.attack_iterations = _default_attack_iterations(args.eps, args.eps_step)
    model.cuda()

#########################################################################################
###    Prepare Adversarial Data
#########################################################################################
    advdatapath = "../attacks/" + args.network + "/" + args.attacktype +"/" + args.dataset +"/" + str(args.eps)
    # NOTE(review): the cache key does not include eps_step or attack_iterations, so cached attacks
    # generated with different step settings will be reused silently.
    adv_files = ["x_train_adv", "x_val_adv", "x_test_adv", "y_train", "y_val", "y_test"]

    ### We attack the entire network and save it for future experiments to reduce runtime.
    ### If a complete attack has already been generated for the given dataset, model, and attack
    ### then the data will be loaded instead of generated (a partially written cache is regenerated).
    if all(os.path.exists(advdatapath + "/" + name + ".npy") for name in adv_files):
        x_train_adv = np.load(advdatapath + "/x_train_adv.npy")
        print('x_train_adv loaded.')
        x_test_adv = np.load(advdatapath + "/x_test_adv.npy")
        print('x_test_adv loaded.')
        x_val_adv = np.load(advdatapath + "/x_val_adv.npy")
        print('x_val_adv loaded.')
        y_train = np.load(advdatapath + "/y_train.npy")
        y_test = np.load(advdatapath + "/y_test.npy")
        y_val = np.load(advdatapath + "/y_val.npy")
    else:
        ### Otherwise generate the new attacked data, saving each split as soon as it is ready
        attack = _build_attack(args, model, lossfn, optimizer)
        os.makedirs(advdatapath, exist_ok=True)

        x_train_adv = _generate_adversarial(attack, args.attacktype, x_train_np, y_train)
        print('x_train_adv created.')
        np.save(advdatapath + "/x_train_adv.npy", x_train_adv)
        np.save(advdatapath + "/y_train.npy", y_train)

        x_val_adv = _generate_adversarial(attack, args.attacktype, x_val_np, y_val)
        print('x_val_adv created.')
        np.save(advdatapath + "/x_val_adv.npy", x_val_adv)
        np.save(advdatapath + "/y_val.npy", y_val)

        x_test_adv = _generate_adversarial(attack, args.attacktype, x_test_np, y_test)
        print('x_test_adv created.')
        np.save(advdatapath + "/x_test_adv.npy", x_test_adv)
        np.save(advdatapath + "/y_test.npy", y_test)

    x_train_adv = torch.from_numpy(x_train_adv).float()
    x_val_adv = torch.from_numpy(x_val_adv).float()
    x_test_adv = torch.from_numpy(x_test_adv).float()

    print("Type of x_train_adv is: ", type(x_train_adv))

    ### Produce the dataloaders for the adversarial data
    advtrainloader = torch.utils.data.DataLoader(_pairs(x_train_adv, y_train), batch_size=args.batchsize, num_workers=1, shuffle=True, drop_last=True, pin_memory=True)
    advvalloader = torch.utils.data.DataLoader(_pairs(x_val_adv, y_val), batch_size=args.batchsize, num_workers=1, shuffle=False, drop_last=True, pin_memory=True)
    advtestloader = torch.utils.data.DataLoader(_pairs(x_test_adv, y_test), batch_size=args.batchsize, num_workers=1, shuffle=False, drop_last=True, pin_memory=True)

#########################################################################################
###    Train on Adversarial Data
#########################################################################################

    ### manager with pretrained normal model
    manager = Manager(args, model, trainloader, valloader, testloader, advtrainloader, advvalloader, advtestloader)

    # Calculate the pretrained model's non-robust accuracy on regular data
    # (manager.eval()[0] is converted with 100 - value, as in the rest of this script)
    pretrained_acc = manager.eval(adversarial=False, data="Test")
    pretrained_acc = 100 - pretrained_acc[0]
    print("Acc of (f_a,f_b) on (X,Y):", pretrained_acc)

    # Calculate the pretrained model's non-robust accuracy on adversarial data
    pretrained_adv_acc = manager.eval(adversarial=True, data="Test")
    pretrained_adv_acc = 100 - pretrained_adv_acc[0]
    print("Acc of (f_a,f_b) on (X_adv,Y):", pretrained_adv_acc)

    ### Store the pretrained network state so that we can reload f_b later
    root_save_path = './saves/' + args.network + "/" + args.attacktype + "/" + args.dataset +"/" + str(args.eps) + "/" + str(args.epochs) + "/"
    normal_path = root_save_path + 'normal_' + str(f_b_start)
    os.makedirs(root_save_path, exist_ok=True)
    manager.save_model(normal_path)

    ### Do adversarial training for the model in three 40-epoch phases with a decaying learning rate,
    ### checkpointing on validation accuracy
    best_val_accuracy = 0
    trt = time.time()
    for phase, lr in enumerate((0.01, 0.001, 0.0001)):
        if phase > 0:
            print("changing lr to " + str(lr))
        for g in optimizer.param_groups:
            g['lr'] = lr
        best_val_accuracy = manager.train(40, optimizer, save=True, best_val_accuracy=best_val_accuracy, adversarial=True)
    trt = time.time() - trt

    ### Reload the best checkpoint from training based on validation set accuracy
    checkpoint_path = ("./saves/" + args.dataset + "/" + args.network + "/" + args.attacktype + "/" + str(args.eps) + "/checkpoint_" + str(args.num_fb_layers) + "_" + str(args.epochs))
    checkpoint = torch.load(checkpoint_path)
    manager.load_model(checkpoint)

    ### Store the model state (f_a*,f_b*)
    adv_path = root_save_path + 'adversarial_' + str(f_b_start)
    manager.save_model(adv_path)

    ### Calculate ACC*
    baseline_acc = manager.eval(adversarial=True, data="Test")
    baseline_acc = 100 - baseline_acc[0]
    print("baseline acc of (f_a*,f_b*) on (X_adv,Y):", baseline_acc)

#########################################################################################
###    Prepare Spliced Model
#########################################################################################

    ### Get the state dict from (f_a,f_b)
    norm_states = torch.load(normal_path)

    ### Load only the weights for [f_b_start:output] to get (f_a*,f_b)
    manager.load_fb(norm_states, f_b_start)

    ### Store model as the starting point of every subnetwork training trial
    spliced_path = root_save_path + 'spliced_' + str(f_b_start)
    manager.save_model(spliced_path)

    spliced_acc = manager.eval(adversarial=True, data="Test")
    spliced_acc = 100 - spliced_acc[0]
    print("spliced acc of (f_a*,f_b) on (X_adv,Y):", spliced_acc)

#########################################################################################
###    Run Experiment
#########################################################################################

    ### Split x_test_adv into per-class subsets and compute the class priors p(y_n) = (# of y_n) / N.
    ### This is for calculating MI conditioned on class label.
    num_classes = 100 if args.dataset == "CIFAR100" else 10
    x_split = [x_test_adv[np.where(y_test[:] == j)] for j in range(num_classes)]
    label_counts = np.bincount(y_test[:], minlength=num_classes) / len(y_test)

    ### Accuracy tolerance used as the early-stopping delta for subnetwork training
    k = 1e-16
    num_trials = 10
    num_mi = 16
    # Columns: 0..15 MI estimates, then the per-trial statistics listed in trial_stats below (16..28)
    results = np.zeros((num_trials, num_mi + 13))

    for i in range(num_trials):
        print('Trial ' + str(i + 1))
        ### This is loading the network with the weights from (f_a*,f_b)
        manager.load_model(torch.load(spliced_path))
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9, nesterov=True)

        ### Prevent updates to f_a* so that only the subnetwork f_b is trained into f_b~
        manager.freeze_fa(f_b_start)

        ### Train f_b for up to args.epochs epochs or until test accuracy is within k of ACC*
        t = time.time()
        test_acc, val_acc, train_acc, break_epoch = manager.train_subnetwork(args.epochs, optimizer, save=True, target_accuracy=baseline_acc, adversarial=True, delta=k)
        t = time.time() - t

        print("getting estimates of MI")
        mi_estimates = _mi_estimates(manager, weight_layers, f_b_index, x_split, label_counts, count=num_mi)
        print('MI estimates: {:.3f}, {:.3f}, {:.3f}, {:.3f}; Time: {:.3f}'.format(*mi_estimates[:4], t))

        trial_stats = [pretrained_acc, pretrained_adv_acc, baseline_acc, spliced_acc,
                       train_acc, val_acc, test_acc, break_epoch, t, trt,
                       args.eps, args.eps_step, args.attack_iterations]
        # Assign element by element so each value (possibly a tensor scalar) is converted individually
        for col, value in enumerate(mi_estimates + trial_stats):
            results[i, col] = value

    path = "./results/" + args.network + "/" + args.attacktype + "/" + args.dataset + "/" + str(args.eps) + "/" + str(args.num_fb_layers) + "/" + str(args.epochs)
    os.makedirs(path, exist_ok=True)
    np.save(path + "/results.npy", results)

    ### Print a summary of the last trial
    summary = [("mi_" + str(m + 1), value) for m, value in enumerate(mi_estimates)]
    summary += [
        ("pretrained_acc", pretrained_acc),
        ("pretrained_adv_acc", pretrained_adv_acc),
        ("baseline_acc", baseline_acc),
        ("spliced_acc", spliced_acc),
        ("train_acc", train_acc),
        ("val_acc", val_acc),
        ("test_acc", test_acc),
        ("break_epoch", break_epoch),
    ]
    print("##############################################################")
    for name, value in summary:
        print("        " + name + ": " + str(value))
    print("##############################################################")
if __name__ == '__main__':
    # Entry point: run the experiment only when executed as a script, never on import.
    main()
